1 Reference

This post is an illustration of work by Julia Silge and Allison Horst.

Here I use repeated k-fold cross-validation to see whether model performance differs across repeated folds.

I also illustrate how to tune the classification threshold of a logistic regression model in order to control the sensitivity and specificity of the fitted model.


Palmer Penguins



2 Cross validation and Repetition

In k-fold cross-validation, the data is randomly split into k folds; with k = 5, we iteratively fit the model on 4 training folds and test it on the held-out fold. We do this 5 times, record the performance each time, and then average the results. This way we get a robust estimate of model performance.

The initial split is random, however, and we don't know whether a single split is representative. To get an even more robust and accurate estimate, we repeat the 5-fold cross-validation several times and average over all repetitions.
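As a concrete sketch, repeated k-fold cross-validation can be written out in a few lines of base R. This is purely illustrative and uses the built-in mtcars data with an arbitrary model of my choosing, not the penguin analysis below (which uses tidymodels):

```r
# Repeated k-fold CV sketch in base R (illustrative only).
# 5 folds, repeated 3 times, on the built-in mtcars data.
set.seed(123)
k <- 5
n_repeats <- 3
acc <- c()
for (r in seq_len(n_repeats)) {
  # A new random fold assignment for every repeat
  folds <- sample(rep(seq_len(k), length.out = nrow(mtcars)))
  for (f in seq_len(k)) {
    train <- mtcars[folds != f, ]
    test  <- mtcars[folds == f, ]
    fit <- glm(am ~ mpg + wt, data = train, family = binomial)
    p <- predict(fit, newdata = test, type = "response")
    acc <- c(acc, mean((p >= 0.5) == (test$am == 1)))  # fold accuracy
  }
}
mean(acc)  # average over all k * n_repeats fold estimates
```

On small folds glm may warn about fitted probabilities of 0 or 1; that is harmless here. The point is only the structure: fold assignment is re-randomized each repeat, and the final estimate averages over every fold of every repeat.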


5-fold Cross Validation



Repeated 5-fold Cross Validation


3 Loading Packages

if (!require('palmerpenguins')) devtools::install_github("allisonhorst/palmerpenguins")
library('palmerpenguins')

require(ggplot2)
require(tidymodels)

4 Data

penguins
## # A tibble: 344 x 8
##    species island    bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
##    <fct>   <fct>              <dbl>         <dbl>             <int>       <int>
##  1 Adelie  Torgersen           39.1          18.7               181        3750
##  2 Adelie  Torgersen           39.5          17.4               186        3800
##  3 Adelie  Torgersen           40.3          18                 195        3250
##  4 Adelie  Torgersen           NA            NA                  NA          NA
##  5 Adelie  Torgersen           36.7          19.3               193        3450
##  6 Adelie  Torgersen           39.3          20.6               190        3650
##  7 Adelie  Torgersen           38.9          17.8               181        3625
##  8 Adelie  Torgersen           39.2          19.6               195        4675
##  9 Adelie  Torgersen           34.1          18.1               193        3475
## 10 Adelie  Torgersen           42            20.2               190        4250
## # ... with 334 more rows, and 2 more variables: sex <fct>, year <int>
glimpse(penguins)
## Rows: 344
## Columns: 8
## $ species           <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel~
## $ island            <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgerse~
## $ bill_length_mm    <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39.2, 34.1, ~
## $ bill_depth_mm     <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19.6, 18.1, ~
## $ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193, 190, 186~
## $ body_mass_g       <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 4675, 3475, ~
## $ sex               <fct> male, female, female, NA, female, male, female, male~
## $ year              <int> 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007, 2007~

Today we will build a logistic regression model to predict the sex of the Palmer penguins, using body-part lengths and body mass.

5 Plot the Data

Plot the data by flipper length and bill length, and map body mass to point size, to look for any relationship with the sex of the penguins.

penguins %>%
  filter(!is.na(sex)) %>%
  ggplot(aes(flipper_length_mm, bill_length_mm, color = sex, size = body_mass_g)) +
  geom_point(alpha = 0.5) +
  facet_wrap(~species)

penguins %>%
  drop_na() %>% 
  select(species, body_mass_g, ends_with("_mm"), sex) %>%
  GGally::ggpairs(aes(color = sex, alpha = 0.5))

It looks like there is a relationship between sex, body-part lengths, and body mass of the penguins.


Body Parts of Penguins




6 Modelling Logistic Regression

6.1 Data for modelling

We drop year and island from our data, and also drop observations with missing sex.

penguins_df <- penguins %>%
  filter(!is.na(sex)) %>%
  select(-year, -island)
penguins_df
## # A tibble: 333 x 6
##    species bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
##    <fct>            <dbl>         <dbl>             <int>       <int> <fct> 
##  1 Adelie            39.1          18.7               181        3750 male  
##  2 Adelie            39.5          17.4               186        3800 female
##  3 Adelie            40.3          18                 195        3250 female
##  4 Adelie            36.7          19.3               193        3450 female
##  5 Adelie            39.3          20.6               190        3650 male  
##  6 Adelie            38.9          17.8               181        3625 female
##  7 Adelie            39.2          19.6               195        4675 male  
##  8 Adelie            41.1          17.6               182        3200 female
##  9 Adelie            38.6          21.2               191        3800 male  
## 10 Adelie            34.6          21.1               198        4400 male  
## # ... with 323 more rows
levels(penguins_df$sex)
## [1] "female" "male"

6.2 Splitting Into Training and Testing Data

We use the tidymodels package to model the sex of the Palmer penguins.

set.seed(123)
penguin_split <- initial_split(penguins_df, strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
penguin_split
## <Analysis/Assess/Total>
## <250/83/333>

6.3 Create Repeated 10-fold Cross Validation Dataset

penguin_cv <- vfold_cv(data = penguin_train, v = 10, repeats = 10, strata = sex)
penguin_cv
## #  10-fold cross-validation repeated 10 times using stratification 
## # A tibble: 100 x 3
##    splits           id       id2   
##    <list>           <chr>    <chr> 
##  1 <split [224/26]> Repeat01 Fold01
##  2 <split [224/26]> Repeat01 Fold02
##  3 <split [224/26]> Repeat01 Fold03
##  4 <split [224/26]> Repeat01 Fold04
##  5 <split [225/25]> Repeat01 Fold05
##  6 <split [225/25]> Repeat01 Fold06
##  7 <split [226/24]> Repeat01 Fold07
##  8 <split [226/24]> Repeat01 Fold08
##  9 <split [226/24]> Repeat01 Fold09
## 10 <split [226/24]> Repeat01 Fold10
## # ... with 90 more rows

6.4 Specify Model

glm_spec <- logistic_reg() %>%
  set_engine("glm")

6.5 Specify Workflow

penguin_wf <- workflow() %>%
  add_formula(sex ~ .)

6.6 Fit the Logistic Model in CV datasets

fit_resamples() fits the logistic model on each of the 100 analysis sets in penguin_cv and evaluates it on each of the 100 corresponding assessment sets. It also saves the predictions so we can evaluate the model's performance on each resample.

### Parallel Processing makes things faster
### tidymodels support parallel processing

doParallel::registerDoParallel()

glm_rs <- penguin_wf %>%
  add_model(glm_spec) %>%
  fit_resamples(
    resamples = penguin_cv,
    control = control_resamples(save_pred = TRUE, verbose = TRUE)
  )

glm_rs
## # Resampling results
## # 10-fold cross-validation repeated 10 times using stratification 
## # A tibble: 100 x 6
##    splits          id       id2   .metrics       .notes         .predictions    
##    <list>          <chr>    <chr> <list>         <list>         <list>          
##  1 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
##  2 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
##  3 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
##  4 <split [224/26~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [26 x 6~
##  5 <split [225/25~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [25 x 6~
##  6 <split [225/25~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [25 x 6~
##  7 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
##  8 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
##  9 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## 10 <split [226/24~ Repeat01 Fold~ <tibble [2 x ~ <tibble [0 x ~ <tibble [24 x 6~
## # ... with 90 more rows

6.7 Check Model Accuracy

The accuracy and AUC reported here are means over all 100 CV resamples.

collect_metrics(glm_rs)
## # A tibble: 2 x 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.900   100 0.00589 Preprocessor1_Model1
## 2 roc_auc  binary     0.965   100 0.00306 Preprocessor1_Model1
glm_rs %>% 
  unnest(.metrics) %>% 
  ggplot(aes(id2, .estimate, color = .metric)) + 
  geom_point() +
  labs(title = "Accuracy and AUC over Folds and Repetitions",
       x = "Fold", 
       y = NULL,
       color = "Metric") +
  facet_wrap(.metric ~ id) +
  theme(axis.text.x = element_text(size=6, angle = 90)) 

6.8 Confusion Matrix

The confusion matrix below shows counts averaged over the resamples.

glm_rs %>%
  conf_mat_resampled()
## # A tibble: 4 x 3
##   Prediction Truth   Freq
##   <fct>      <fct>  <dbl>
## 1 female     female 112. 
## 2 female     male    12.7
## 3 male       female  12.2
## 4 male       male   113.

6.9 ROC curves over folds and repeats

The ROC curves show similar performance across repeats, although some variation is visible across folds within a repeat.

glm_rs %>%
  collect_predictions() %>%
  group_by(id, id2) %>%
  roc_curve(sex, .pred_female) %>%
  ggplot(aes(1 - specificity, sensitivity, color = id2)) +
  geom_abline(lty = 2, color = "gray80", size = 1.5) +
  geom_path(show.legend = TRUE, alpha = 0.5, size = 0.8) +
  coord_equal() +
  facet_wrap(~id) + 
  labs(color='Fold', x = "1 - Specificity", y = "Sensitivity", title = "ROC & AUC by Fold and Repeat") +
  theme_minimal()

6.10 Finalize the Model on the Whole Training Data and Test on the Testing Data

penguin_final <- penguin_wf %>%
  add_model(glm_spec) %>%
  last_fit(penguin_split)

penguin_final
## # Resampling results
## # Manual resampling 
## # A tibble: 1 x 6
##   splits       id           .metrics      .notes       .predictions    .workflow
##   <list>       <chr>        <list>        <list>       <list>          <list>   
## 1 <split [250~ train/test ~ <tibble [2 x~ <tibble [0 ~ <tibble [83 x ~ <workflo~

6.11 Metrics over Test Data

The metrics on the testing data show performance similar to the CV estimates. This indicates an absence of overfitting, and good predictive performance of the logistic model on new data.

collect_metrics(penguin_final)
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.940 Preprocessor1_Model1
## 2 roc_auc  binary         0.991 Preprocessor1_Model1
collect_predictions(penguin_final) %>%
  conf_mat(sex, .pred_class)
##           Truth
## Prediction female male
##     female     39    3
##     male        2   39
collect_predictions(penguin_final) %>%
  sensitivity(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 sens    binary         0.951
collect_predictions(penguin_final) %>%
  specificity(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 spec    binary         0.929
collect_predictions(penguin_final) %>%
  precision(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 precision binary         0.929

6.12 Tuning Threshold Probability for Classifying Female Penguin

The roc_curve() function constructs the full ROC curve over a grid of threshold values and returns a tibble.

### First collect predictions on the test data from the model
penguin_final %>%
  collect_predictions()
## # A tibble: 83 x 7
##    id            .pred_female .pred_male  .row .pred_class sex   .config        
##    <chr>                <dbl>      <dbl> <int> <fct>       <fct> <chr>          
##  1 train/test s~    0.0117       0.988       9 male        male  Preprocessor1_~
##  2 train/test s~    0.499        0.501      12 male        fema~ Preprocessor1_~
##  3 train/test s~    0.0000458    1.00       13 male        male  Preprocessor1_~
##  4 train/test s~    0.985        0.0155     24 female      fema~ Preprocessor1_~
##  5 train/test s~    0.992        0.00754    26 female      fema~ Preprocessor1_~
##  6 train/test s~    0.729        0.271      27 female      male  Preprocessor1_~
##  7 train/test s~    0.0262       0.974      32 male        male  Preprocessor1_~
##  8 train/test s~    0.415        0.585      42 male        male  Preprocessor1_~
##  9 train/test s~    0.000236     1.00       44 male        male  Preprocessor1_~
## 10 train/test s~    0.835        0.165      45 female      fema~ Preprocessor1_~
## # ... with 73 more rows
### Then construct roc curve with these predictions
penguin_final %>%
  collect_predictions() %>%
  roc_curve(sex, .pred_female)
## # A tibble: 85 x 3
##    .threshold specificity sensitivity
##         <dbl>       <dbl>       <dbl>
##  1 -Inf            0                1
##  2    1.24e-7      0                1
##  3    1.24e-5      0.0238           1
##  4    1.78e-5      0.0476           1
##  5    4.58e-5      0.0714           1
##  6    7.60e-5      0.0952           1
##  7    1.12e-4      0.119            1
##  8    1.18e-4      0.143            1
##  9    1.49e-4      0.167            1
## 10    2.36e-4      0.190            1
## # ... with 75 more rows

Using this tibble we can draw the ROC curve and find an optimal threshold for classifying female penguins.

### Hover your mouse over this plot

r <- penguin_final %>%
  collect_predictions() %>%
  roc_curve(sex, .pred_female) %>% 
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_point(size = 0.2, aes(color = .threshold)) +
  geom_abline(lty = 2, 
              color = "gray80", 
              size = 1.5) +
  geom_path(show.legend = TRUE, 
            alpha = 0.3, 
            size = 0.5) +
  geom_text(aes(label = round(.threshold, 2)), 
            size = 2.5, 
            vjust = -0.5, 
            fontface = "bold")
plotly::ggplotly(r)

The ROC curve shows that a threshold of 0.75 gives 100% specificity. By changing the classification threshold to 0.75 we can predict males more accurately, at the cost of a small error in predicting females.

We can use the probably package to change the threshold and make new predictions; see the probably package documentation for more details.

### We need probably package
library(probably)

### set threshold
thresh <- 0.75

### For more information,
### run ?roc_curve
### run ?make_two_class_pred

### Mutate .pred_class with new threshold

new_preds <- penguin_final %>%
  collect_predictions()  %>%
  ### mutate .pred_class with new threshold
  mutate(.pred_class = make_two_class_pred(.pred_female, ### Predicted Probability
                                           levels(sex),
                                           threshold = thresh), ### Threshold
         .pred_class = factor(.pred_class, levels = levels(sex)))

### With New Threshold, Performance on Test Data

### Confusion Matrix 
new_preds %>%
  conf_mat(sex, .pred_class)
##           Truth
## Prediction female male
##     female     38    0
##     male        3   42
### Sensitivity
new_preds %>%
  sensitivity(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 sens    binary         0.927
### Specificity
new_preds %>%
  specificity(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 spec    binary             1
### Precision
new_preds %>%
  precision(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric   .estimator .estimate
##   <chr>     <chr>          <dbl>
## 1 precision binary             1

We can see that the new threshold performs better in terms of specificity. This could be useful if the modeler (say, a biologist) needs to classify males more accurately than females for scientific purposes. We could instead set threshold = 0.17 for 100% sensitivity, if we want to classify females more accurately than males.

By default tidymodels predicts with threshold = 0.50 for logistic regression.
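The effect of the threshold can be seen with a toy base-R sketch. The probabilities and labels here are made up for illustration, not output from the penguin model:

```r
# Toy sketch: how the classification threshold changes predictions.
# p_female is a hypothetical vector of predicted probabilities of "female".
p_female <- c(0.95, 0.74, 0.51, 0.30, 0.80)
truth    <- factor(c("female", "male", "male", "male", "female"),
                   levels = c("female", "male"))

classify <- function(p, threshold) {
  factor(ifelse(p >= threshold, "female", "male"),
         levels = c("female", "male"))
}

table(Prediction = classify(p_female, 0.50), Truth = truth)
# Raising the threshold predicts "female" only when the model is more
# confident, which raises specificity (fewer males called female).
table(Prediction = classify(p_female, 0.75), Truth = truth)
```

This is essentially what make_two_class_pred() does for us, with the added benefit of handling equivocal zones and factor levels consistently.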


6.13 Odds Estimates and Variables

It looks like bill depth and bill length have the highest importance in predicting the sex of the penguins; these two variables separate the sexes the most.

penguin_final$.workflow[[1]] %>%
  tidy(exponentiate = TRUE)
## # A tibble: 7 x 5
##   term              estimate std.error statistic       p.value
##   <chr>                <dbl>     <dbl>     <dbl>         <dbl>
## 1 (Intercept)       3.12e-35  13.5         -5.90 0.00000000369
## 2 speciesChinstrap  1.34e- 3   1.70        -3.89 0.000101     
## 3 speciesGentoo     1.08e- 4   2.89        -3.16 0.00159      
## 4 bill_length_mm    1.78e+ 0   0.137        4.20 0.0000268    
## 5 bill_depth_mm     3.89e+ 0   0.373        3.64 0.000273     
## 6 flipper_length_mm 1.07e+ 0   0.0538       1.31 0.189        
## 7 body_mass_g       1.01e+ 0   0.00108      4.70 0.00000260
penguin_final$.workflow[[1]] %>%
  tidy(exponentiate = TRUE) %>% 
  select(term, estimate) %>% 
  mutate(term = as.factor(term)) %>% 
  ggplot(aes(reorder(term, estimate), estimate, 
             fill =term)) +
  geom_bar(stat = "identity", show.legend = FALSE, width = 0.7) + 
  labs(title = "Multiplicative change in the odds of a penguin being male per one-unit increase in each variable",
       x = "Variable", y = "Odds ratio") +
  geom_text(aes(label = round(estimate, 3)), 
            nudge_y = 0.15 , 
            size = 4.5, 
            colour = 'black', 
            fontface = 'bold') +
  coord_flip() +
  theme_light()

Note that glm models the log-odds of the second factor level, here "male": a 1 mm increase in bill depth multiplies the odds of a penguin being male by almost 4. The very small odds ratios for the species terms reflect that, at the same measurements, Chinstrap and Gentoo penguins have much lower odds of being male than Adelie penguins.
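For reference, the odds ratios above are just the exponentiated logistic-regression coefficients, which is what tidy(exponentiate = TRUE) computes. A minimal base-R sketch on the built-in mtcars data (not the penguin model):

```r
# exp() of a logistic-regression coefficient gives the multiplicative
# change in the odds of the event per one-unit increase in the predictor,
# holding the other predictors fixed. Illustrated on mtcars.
fit <- glm(am ~ mpg + wt, data = mtcars, family = binomial)
odds_ratios <- exp(coef(fit))
odds_ratios
```

An odds ratio above 1 means the predictor increases the odds of the event; below 1, it decreases them.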


7 Bill Depth vs Bill Length Scatter

penguins %>%
  filter(!is.na(sex)) %>%
  ggplot(aes(bill_depth_mm, bill_length_mm, color = sex, size = body_mass_g)) +
  geom_point(alpha = 0.5) +
  facet_wrap(~species)


Bill Length and Bill Depth
